#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/library.h>
namespace c10d {
namespace ops {
// Below are ProcessGroup's corresponding ops for each backend. Ops are
// routed through the dispatcher to be dispatched to the appropriate backend.
// Currently this routing is effectively a no-op, since the process group does
// not yet hold a list of backends to choose from.
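//
// Illustrative sketch (an assumption, not part of the original sources): a
// caller reaches the kernels below by looking the operator up in the
// dispatcher and invoking a typed handle. The schema name "c10d::send" and
// its exact signature are assumed to match the m.def() declaration registered
// elsewhere for this library; tensors, process_group, dstRank, and tag are
// caller-provided arguments.
//
//   static auto op = c10::Dispatcher::singleton()
//       .findSchemaOrThrow("c10d::send", "")
//       .typed<c10::intrusive_ptr<Work>(
//           at::TensorList,
//           const c10::intrusive_ptr<ProcessGroup>&,
//           int64_t,
//           int64_t)>();
//   // The dispatcher inspects the tensors' dispatch keys and routes the call
//   // to send_cpu (CPU tensors) or send_cuda (CUDA tensors).
//   auto work = op.call(tensors, process_group, dstRank, tag);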
c10::intrusive_ptr<Work> send_cpu(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t dstRank,
int64_t tag) {
auto tensor_vec = tensors.vec();
return process_group->send(
tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));
}
c10::intrusive_ptr<Work> send_cuda(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t dstRank,
int64_t tag) {
auto tensor_vec = tensors.vec();
return process_group->send(
tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));
}
c10::intrusive_ptr<Work> recv_cpu_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t srcRank,
int64_t tag) {
auto tensor_vec = tensors.vec();
return process_group->recv(
tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
}
c10::intrusive_ptr<Work> recv_cuda_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t srcRank,
int64_t tag) {
auto tensor_vec = tensors.vec();
return process_group->recv(
tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
}
c10::intrusive_ptr<Work> reduce_cpu_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
const c10::intrusive_ptr<ReduceOp>& reduce_op,
int64_t root_rank,
int64_t root_tensor,
int64_t timeout) {
auto tensor_vec = tensors.vec();
return process_group->reduce(
tensor_vec,
ReduceOptions{
*reduce_op.get(),
root_rank,
root_tensor,
std::chrono::milliseconds(timeout)});
}
c10::intrusive_ptr<Work> reduce_cuda_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
const c10::intrusive_ptr<ReduceOp>& reduce_op,
int64_t root_rank,
int64_t root_tensor,
int64_t timeout) {
auto tensor_vec = tensors.vec();
return process_group->reduce(
tensor_vec,
ReduceOptions{
*reduce_op.get(),
root_rank,
root_tensor,
std::chrono::milliseconds(timeout)});
}
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cpu_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t root_rank,
int64_t root_tensor,
int64_t timeout) {
auto tensor_vec = tensors.vec();
auto work = process_group->broadcast(
tensor_vec,
BroadcastOptions{
root_rank, root_tensor, std::chrono::milliseconds(timeout)});
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
std::move(tensor_vec), work);
}
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cuda_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t root_rank,
int64_t root_tensor,
int64_t timeout) {
auto tensor_vec = tensors.vec();
auto work = process_group->broadcast(
tensor_vec,
BroadcastOptions{
root_rank, root_tensor, std::chrono::milliseconds(timeout)});
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
std::move(tensor_vec), work);
}
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cpu_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
const c10::intrusive_ptr<ReduceOp>& reduce_op,
int64_t timeout) {
auto tensor_vec = tensors.vec();
auto work = process_group->allreduce(
tensor_vec,
AllreduceOptions{*reduce_op.get(), std::chrono::milliseconds(timeout)});
// Return input tensors as output tensors to make inplace allreduce look like
// a functional API, so that make_fx can correctly build the dependencies in
// the graph later (see the aliasing sketch after this function).
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
std::move(tensor_vec), work);
}
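// Illustrative sketch (an assumption, not part of the original sources): the
// "functional" return above copies no tensor data; the returned vector holds
// handles that alias the inputs, which allreduce mutates in place. A caller
// could observe this as:
//
//   auto [outputs, work] = call_allreduce(inputs, pg, op, timeout);  // hypothetical caller helper
//   work->wait();
//   // Same storage: the reduced values are visible through both handles.
//   TORCH_INTERNAL_ASSERT(outputs[0].is_alias_of(inputs[0]));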
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cuda_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
const c10::intrusive_ptr<ReduceOp>& reduce_op,
int64_t timeout) {
auto tensor_vec = tensors.vec();
auto work = process_group->allreduce(
tensor_vec,
AllreduceOptions{*reduce_op.get(), std::chrono::milliseconds(timeout)});
// Return input tensors as output tensors to make inplace allreduce look like
// a functional API, so that make_fx can correctly build the dependencies in
// the graph later.
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
std::move(tensor_vec), work);
}
std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>
allgather_cpu_(
const std::vector<std::vector<at::Tensor>>& output_tensors,
const std::vector<at::Tensor>& input_tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t timeout) {
auto work = process_group->allgather(
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
const_cast<std::vector<at::Tensor>&>(input_tensors),
AllgatherOptions{std::chrono::milliseconds(timeout)});
// Copy output tensors (not storage) so that this can be used in a functional
// manner (see the shallow-copy sketch after this function)
return std::
tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>(
output_tensors, work);
}
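// Illustrative sketch (an assumption, not part of the original sources):
// returning `output_tensors` copies only the std::vector structure and the
// at::Tensor handles; each copied handle still points at the storage that
// allgather wrote into, so no gathered data is duplicated.
//
//   auto [gathered, work] = call_allgather(output_tensors, input_tensors, pg, timeout);  // hypothetical caller helper
//   work->wait();
//   TORCH_INTERNAL_ASSERT(gathered[0][0].is_alias_of(output_tensors[0][0]));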
std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>
allgather_cuda_(
const std::vector<std::vector<at::Tensor>>& output_tensors,
const std::vector<at::Tensor>& input_tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t timeout) {
auto work = process_group->allgather(
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
const_cast<std::vector<at::Tensor>&>(input_tensors),
AllgatherOptions{std::chrono::milliseconds(timeout)});
// Copy output tensors (not storage) so that this can be used in a functional
// manner
return std::
tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>(
output_tensors, work);
}
// register functions to dispatcher
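// Illustrative sketch (an assumption, not part of the original sources):
// TORCH_LIBRARY_IMPL binds a kernel to one (library, dispatch key) pair; the
// matching operator schemas are declared elsewhere via TORCH_LIBRARY and
// m.def(). For a hypothetical library "myops" the pattern is:
//
//   TORCH_LIBRARY(myops, m) {
//     m.def("add_one(Tensor t) -> Tensor");
//   }
//   TORCH_LIBRARY_IMPL(myops, CPU, m) {
//     m.impl("add_one", add_one_cpu);   // hypothetical CPU kernel
//   }
//   TORCH_LIBRARY_IMPL(myops, CUDA, m) {
//     m.impl("add_one", add_one_cuda);  // hypothetical CUDA kernel
//   }
//
// At call time the dispatcher selects the kernel whose dispatch key matches
// the argument tensors, which is how the registrations below pick the CPU,
// CUDA, or Sparse implementations defined above.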
namespace {
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("send", send_cpu);
}
TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("send", send_cuda);
}
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("recv_", recv_cpu_);
}
TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("recv_", recv_cuda_);
}
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("reduce_", reduce_cpu_);
}
TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("reduce_", reduce_cuda_);
}
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("broadcast_", broadcast_cpu_);
}
TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("broadcast_", broadcast_cuda_);
}
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("allreduce_", allreduce_cpu_);
}
// TODO: The SparseCPU/SparseCUDA dispatched methods are only used to support
// sparse all_reduce in the Gloo backend (see the sparse dispatch sketch below).
TORCH_LIBRARY_IMPL(c10d, SparseCPU, m) {
m.impl("allreduce_", allreduce_cpu_);
}
TORCH_LIBRARY_IMPL(c10d, SparseCUDA, m) {
m.impl("allreduce_", allreduce_cuda_);
}
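// Illustrative sketch (an assumption, not part of the original sources): a
// sparse COO tensor carries the SparseCPU (or SparseCUDA) dispatch key, so an
// allreduce on it is routed to the Sparse registrations above rather than the
// dense CPU/CUDA kernels:
//
//   auto sp = at::sparse_coo_tensor(indices, values, {4});  // indices/values assumed defined
//   TORCH_INTERNAL_ASSERT(sp.key_set().has(c10::DispatchKey::SparseCPU));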
TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("allreduce_", allreduce_cuda_);
}
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("allgather_", allgather_cpu_);
}
TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("allgather_", allgather_cuda_);
}
} // namespace
} // namespace ops
} // namespace c10d