blob: c74a85e2e2e21a0ed2ffb0c2185f2b1d6785b584 [file] [log] [blame]
#include <torch/csrc/distributed/c10d/Ops.hpp>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
namespace c10d {
namespace {
// Dispatcher-facing wrapper for broadcast: unpacks the boxed scalar arguments
// into a BroadcastOptions struct and forwards to the process group.
c10::intrusive_ptr<ProcessGroup::Work> broadcast_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t root_tensor,
    int64_t timeout) {
  // ProcessGroup::broadcast takes a mutable vector, so materialize one from
  // the TensorList view.
  std::vector<at::Tensor> inputs = tensors.vec();
  BroadcastOptions opts{};
  opts.rootRank = root_rank;
  opts.rootTensor = root_tensor;
  opts.timeout = std::chrono::milliseconds(timeout);
  return process_group->broadcast(inputs, opts);
}
// Dispatcher-facing wrapper for allreduce: converts the boxed scalars into
// AllreduceOptions and forwards to the process group.
c10::intrusive_ptr<ProcessGroup::Work> allreduce_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t reduce_op,
    int64_t timeout) {
  // ProcessGroup::allreduce takes a mutable vector, so materialize one.
  std::vector<at::Tensor> inputs = tensors.vec();
  AllreduceOptions opts{};
  opts.reduceOp = static_cast<ReduceOp>(reduce_op);
  opts.timeout = std::chrono::milliseconds(timeout);
  return process_group->allreduce(inputs, opts);
}
// Dispatcher-facing wrapper for allgather.
c10::intrusive_ptr<ProcessGroup::Work> allgather_(
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    const std::vector<at::Tensor>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t timeout) {
  // Bridge the const refs this wrapper receives to the mutable refs
  // ProcessGroup::allgather expects.
  auto& outputs =
      const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors);
  auto& inputs = const_cast<std::vector<at::Tensor>&>(input_tensors);
  AllgatherOptions opts{};
  opts.timeout = std::chrono::milliseconds(timeout);
  return process_group->allgather(outputs, inputs, opts);
}
// Dispatcher-facing wrapper for reduce_scatter.
c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter_(
    const std::vector<at::Tensor>& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t reduce_op,
    int64_t timeout) {
  // Bridge the const refs this wrapper receives to the mutable refs
  // ProcessGroup::reduce_scatter expects.
  auto& outputs = const_cast<std::vector<at::Tensor>&>(output_tensors);
  auto& inputs =
      const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors);
  ReduceScatterOptions opts{};
  opts.reduceOp = static_cast<ReduceOp>(reduce_op);
  opts.timeout = std::chrono::milliseconds(timeout);
  return process_group->reduce_scatter(outputs, inputs, opts);
}
// Dispatcher-facing wrapper for reduce: converts the boxed scalars into
// ReduceOptions and forwards to the process group.
c10::intrusive_ptr<ProcessGroup::Work> reduce_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t reduce_op,
    int64_t root_rank,
    int64_t root_tensor,
    int64_t timeout) {
  // ProcessGroup::reduce takes a mutable vector, so materialize one.
  std::vector<at::Tensor> inputs = tensors.vec();
  ReduceOptions opts{};
  opts.reduceOp = static_cast<ReduceOp>(reduce_op);
  opts.rootRank = root_rank;
  opts.rootTensor = root_tensor;
  opts.timeout = std::chrono::milliseconds(timeout);
  return process_group->reduce(inputs, opts);
}
// Dispatcher-facing wrapper for gather.
c10::intrusive_ptr<ProcessGroup::Work> gather_(
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    const std::vector<at::Tensor>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t timeout) {
  // Bridge the const refs this wrapper receives to the mutable refs
  // ProcessGroup::gather expects.
  auto& outputs =
      const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors);
  auto& inputs = const_cast<std::vector<at::Tensor>&>(input_tensors);
  GatherOptions opts{};
  opts.rootRank = root_rank;
  opts.timeout = std::chrono::milliseconds(timeout);
  return process_group->gather(outputs, inputs, opts);
}
// Registers the c10d custom classes and collective ops with the torch
// dispatcher under the "c10d" namespace.
TORCH_LIBRARY(c10d, m) {
// The following ProcessGroup and Work definitions act more like declarations:
// they expose the two classes to TorchScript without exposing any of their
// implementation details.
m.class_<ProcessGroup>("ProcessGroup").def(torch::init<int64_t, int64_t>());
m.class_<ProcessGroup::Work>("Work").def(torch::init<>());
// It's important to register each op to the CompositeExplicitAutograd key to
// enable __torch_dispatch__ interception of these collectives.
m.def(
"broadcast_",
dispatch(c10::DispatchKey::CompositeExplicitAutograd, broadcast_));
m.def(
"allreduce_",
dispatch(c10::DispatchKey::CompositeExplicitAutograd, allreduce_));
m.def(
"allgather_",
dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_));
m.def(
"reduce_scatter_",
dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_scatter_));
m.def(
"reduce_",
dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_));
m.def(
"gather_",
dispatch(c10::DispatchKey::CompositeExplicitAutograd, gather_));
}
} // namespace
namespace ops {
// Public entry point for broadcast: looks up the registered c10d::broadcast_
// op and calls it with the fields of `opts` unpacked into scalars.
c10::intrusive_ptr<ProcessGroup::Work> broadcast(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::TensorList tensors,
    const BroadcastOptions& opts) {
  using BroadcastFn = c10::intrusive_ptr<::c10d::ProcessGroup::Work>(
      at::TensorList,
      const c10::intrusive_ptr<::c10d::ProcessGroup>&,
      int64_t,
      int64_t,
      int64_t);
  static auto dispatch_op = c10::Dispatcher::singleton()
                                .findSchemaOrThrow("c10d::broadcast_", "")
                                .typed<BroadcastFn>();
  // Unboxing the opts here and re-boxing them in the custom op is awkward,
  // but making the options a CustomClassHolder would be more complicated;
  // leave it as is for now.
  return dispatch_op.call(
      tensors,
      process_group,
      opts.rootRank,
      opts.rootTensor,
      opts.timeout.count());
}
// Public entry point for allreduce: looks up the registered c10d::allreduce_
// op and calls it with the fields of `opts` unpacked into scalars.
c10::intrusive_ptr<ProcessGroup::Work> allreduce(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::TensorList tensors,
    const AllreduceOptions& opts) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("c10d::allreduce_", "")
                       .typed<c10::intrusive_ptr<::c10d::ProcessGroup::Work>(
                           at::TensorList,
                           const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                           int64_t,
                           int64_t)>();
  return op.call(
      tensors,
      process_group,
      // The schema declares the reduce op as int64_t; cast to the signed
      // type directly rather than via uint64_t, which relied on an implicit
      // unsigned->signed conversion at the call boundary.
      static_cast<int64_t>(opts.reduceOp),
      opts.timeout.count());
}
// Public entry point for allgather: looks up the registered c10d::allgather_
// op and calls it with the timeout unpacked from `opts`.
c10::intrusive_ptr<ProcessGroup::Work> allgather(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    const std::vector<at::Tensor>& input_tensors,
    const AllgatherOptions& opts) {
  using AllgatherFn = c10::intrusive_ptr<::c10d::ProcessGroup::Work>(
      const std::vector<std::vector<at::Tensor>>&,
      const std::vector<at::Tensor>&,
      const c10::intrusive_ptr<::c10d::ProcessGroup>&,
      int64_t);
  static auto dispatch_op = c10::Dispatcher::singleton()
                                .findSchemaOrThrow("c10d::allgather_", "")
                                .typed<AllgatherFn>();
  return dispatch_op.call(
      output_tensors, input_tensors, process_group, opts.timeout.count());
}
// Public entry point for reduce_scatter: looks up the registered
// c10d::reduce_scatter_ op and calls it with the fields of `opts` unpacked
// into scalars.
c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<at::Tensor>& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const ReduceScatterOptions& opts) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("c10d::reduce_scatter_", "")
                       .typed<c10::intrusive_ptr<::c10d::ProcessGroup::Work>(
                           const std::vector<at::Tensor>&,
                           const std::vector<std::vector<at::Tensor>>&,
                           const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                           int64_t,
                           int64_t)>();
  return op.call(
      output_tensors,
      input_tensors,
      process_group,
      // The schema declares the reduce op as int64_t; cast to the signed
      // type directly rather than via uint64_t, which relied on an implicit
      // unsigned->signed conversion at the call boundary.
      static_cast<int64_t>(opts.reduceOp),
      opts.timeout.count());
}
// Public entry point for reduce: looks up the registered c10d::reduce_ op
// and calls it with the fields of `opts` unpacked into scalars.
c10::intrusive_ptr<ProcessGroup::Work> reduce(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::TensorList tensors,
    const ReduceOptions& opts) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("c10d::reduce_", "")
                       .typed<c10::intrusive_ptr<::c10d::ProcessGroup::Work>(
                           at::TensorList,
                           const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                           int64_t,
                           int64_t,
                           int64_t,
                           int64_t)>();
  return op.call(
      tensors,
      process_group,
      // The schema declares the reduce op as int64_t; cast to the signed
      // type directly rather than via uint64_t, which relied on an implicit
      // unsigned->signed conversion at the call boundary.
      static_cast<int64_t>(opts.reduceOp),
      opts.rootRank,
      opts.rootTensor,
      opts.timeout.count());
}
// Public entry point for gather: looks up the registered c10d::gather_ op
// and calls it with the fields of `opts` unpacked into scalars.
c10::intrusive_ptr<ProcessGroup::Work> gather(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    const std::vector<at::Tensor>& input_tensors,
    const GatherOptions& opts) {
  using GatherFn = c10::intrusive_ptr<::c10d::ProcessGroup::Work>(
      const std::vector<std::vector<at::Tensor>>&,
      const std::vector<at::Tensor>&,
      const c10::intrusive_ptr<::c10d::ProcessGroup>&,
      int64_t,
      int64_t);
  static auto dispatch_op = c10::Dispatcher::singleton()
                                .findSchemaOrThrow("c10d::gather_", "")
                                .typed<GatherFn>();
  return dispatch_op.call(
      output_tensors,
      input_tensors,
      process_group,
      opts.rootRank,
      opts.timeout.count());
}
} // namespace ops
} // namespace c10d